This project explores a new release of lime, an R port of a Python library that explains machine learning model predictions on a per-observation basis. For this project we will use a pre-trained ImageNet model, available through keras, to explore lime. The vgg16 model is an image classification model that attempts to classify pictures into 1000 different categories.
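Before diving in, it helps to see lime's core idea in miniature: perturb the observation being explained, weight the perturbations by how close they stay to the original, and fit a simple weighted linear model whose coefficients serve as the local explanation. The base-R sketch below illustrates only this principle, not lime's actual implementation; `black_box` is a made-up stand-in for a trained model.

```r
set.seed(42)
# stand-in for a trained model (not a real classifier)
black_box <- function(x) 3 * x[, 1] - 2 * x[, 2] + x[, 1] * x[, 2]

obs <- c(1, 1)                                    # observation to explain
perturbed <- matrix(rnorm(200 * 2, mean = obs), ncol = 2)
preds <- black_box(perturbed)

# weight each perturbation by its closeness to the original observation
dists <- sqrt(rowSums(sweep(perturbed, 2, obs)^2))
weights <- exp(-dists^2)

# the local surrogate: a weighted linear fit around obs;
# its coefficients approximate the model's local feature effects
surrogate <- lm(preds ~ perturbed, weights = weights)
coef(surrogate)
```

Near `obs = (1, 1)` the coefficients approximate the black box's local gradient, which is exactly the kind of per-observation feature attribution lime produces.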
Load Libraries and Model
library(keras)
library(lime)
library(magick)
## Linking to ImageMagick 6.9.9.14
## Enabled features: cairo, freetype, fftw, ghostscript, lcms, pango, rsvg, webp
## Disabled features: fontconfig, x11
library(abind)
model <- application_vgg16(
  weights = "imagenet",
  include_top = TRUE
)
model
## Model
## Model: "vgg16"
## ___________________________________________________________________________
## Layer (type) Output Shape Param #
## ===========================================================================
## input_1 (InputLayer) [(None, 224, 224, 3)] 0
## ___________________________________________________________________________
## block1_conv1 (Conv2D) (None, 224, 224, 64) 1792
## ___________________________________________________________________________
## block1_conv2 (Conv2D) (None, 224, 224, 64) 36928
## ___________________________________________________________________________
## block1_pool (MaxPooling2D) (None, 112, 112, 64) 0
## ___________________________________________________________________________
## block2_conv1 (Conv2D) (None, 112, 112, 128) 73856
## ___________________________________________________________________________
## block2_conv2 (Conv2D) (None, 112, 112, 128) 147584
## ___________________________________________________________________________
## block2_pool (MaxPooling2D) (None, 56, 56, 128) 0
## ___________________________________________________________________________
## block3_conv1 (Conv2D) (None, 56, 56, 256) 295168
## ___________________________________________________________________________
## block3_conv2 (Conv2D) (None, 56, 56, 256) 590080
## ___________________________________________________________________________
## block3_conv3 (Conv2D) (None, 56, 56, 256) 590080
## ___________________________________________________________________________
## block3_pool (MaxPooling2D) (None, 28, 28, 256) 0
## ___________________________________________________________________________
## block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160
## ___________________________________________________________________________
## block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808
## ___________________________________________________________________________
## block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808
## ___________________________________________________________________________
## block4_pool (MaxPooling2D) (None, 14, 14, 512) 0
## ___________________________________________________________________________
## block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808
## ___________________________________________________________________________
## block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808
## ___________________________________________________________________________
## block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808
## ___________________________________________________________________________
## block5_pool (MaxPooling2D) (None, 7, 7, 512) 0
## ___________________________________________________________________________
## flatten (Flatten) (None, 25088) 0
## ___________________________________________________________________________
## fc1 (Dense) (None, 4096) 102764544
## ___________________________________________________________________________
## fc2 (Dense) (None, 4096) 16781312
## ___________________________________________________________________________
## predictions (Dense) (None, 1000) 4097000
## ===========================================================================
## Total params: 138,357,544
## Trainable params: 138,357,544
## Non-trainable params: 0
## ___________________________________________________________________________
We will pass the following picture of a kitten into the model.
img <- image_read('https://www.data-imaginist.com/assets/images/kitten.jpg')
img_path <- file.path(tempdir(), 'kitten.jpg')
image_write(img, img_path)
plot(as.raster(img))
Next we need to prepare the data for the model by formatting the image as a tensor with the shape the model expects.
image_prep <- function(x) {
  arrays <- lapply(x, function(path) {
    img <- image_load(path, target_size = c(224, 224))  # resize to VGG16's input size
    x <- image_to_array(img)                            # 224 x 224 x 3 array
    x <- array_reshape(x, c(1, dim(x)))                 # add a batch dimension
    x <- imagenet_preprocess_input(x)                   # VGG16 channel preprocessing
  })
  do.call(abind::abind, c(arrays, list(along = 1)))     # stack along the batch dimension
}
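A quick base-R sketch (no keras required) of the shape this produces: each image becomes a 224 x 224 x 3 array, gains a leading batch dimension, and multiple images are stacked along it. Here `fake_img` is random data standing in for a decoded image, and prepending a singleton dimension preserves element order, so plain `array()` mimics `array_reshape()` for this particular reshape.

```r
# stand-in for a decoded 224 x 224 RGB image
fake_img <- array(runif(224 * 224 * 3), dim = c(224, 224, 3))

# add the batch dimension, as array_reshape(x, c(1, dim(x))) does;
# a leading singleton dimension keeps every element in place
batched <- array(fake_img, dim = c(1, 224, 224, 3))
dim(batched)
#> [1]   1 224 224   3
```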
explainer <- lime(img_path, model, image_prep)
First we can check what the model believes this picture to be, which turns out to be some sort of cat. This is clearly correct.
res <- predict(model, image_prep(img_path))
imagenet_decode_predictions(res)
## [[1]]
## class_name class_description score
## 1 n02124075 Egyptian_cat 0.464922875
## 2 n02123045 tabby 0.157048911
## 3 n02123159 tiger_cat 0.104360037
## 4 n02127052 lynx 0.026472222
## 5 n03793489 mouse 0.009612675
Now, using lime, we can explore why the model came to its decision.
model_labels <- readRDS(system.file('extdata', 'imagenet_labels.rds', package = 'lime'))
explainer <- lime(img_path, as_classifier(model, model_labels), image_prep)
plot_superpixels(img_path)
We can break down images into superpixels, which are contiguous areas of similar pixels. lime perturbs these areas to determine which of them are important to the classification. These areas need to be big enough to capture the important parts of the picture, but not so big that the model cannot recognize the image. The following uses more superpixels.
plot_superpixels(img_path, n_superpixels = 200, weight = 40)
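The grouping idea behind superpixels can be mimicked in base R. The sketch below clusters the pixels of a tiny synthetic image on their coordinates and intensity, so nearby, similar pixels land in the same segment. This only illustrates the concept; lime's actual superpixel segmentation algorithm is different.

```r
# a tiny 20 x 20 "image": dark left half, bright right half
n <- 20
img <- outer(1:n, 1:n, function(i, j) as.numeric(j > n / 2))

# one row per pixel: coordinates plus (scaled) intensity
pixels <- expand.grid(x = 1:n, y = 1:n)
pixels$value <- as.vector(img)
feats <- cbind(pixels$x, pixels$y, pixels$value * n)

# hierarchical clustering groups nearby, similar pixels into 8 segments
seg <- cutree(hclust(dist(feats)), k = 8)
table(seg)  # pixels per segment
```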
The following will plot the explanations onto the image.
explanation <- explain(img_path, explainer, n_labels = 2, n_features = 20)
plot_image_explanation(explanation)
The following will block out everything that the model did not consider part of its classification.
plot_image_explanation(explanation, display = 'block', threshold = 0.01)
The following will look at the areas that agree with and contradict the classification.
plot_image_explanation(explanation, threshold = 0, show_negative = TRUE, fill_alpha = 0.6)
Here I will continue to explore lime by tweaking the number of superpixels and labels, using a picture of my own cat.
img <- image_read('CatsVicenti2.jpg')
img_path <- file.path(tempdir(), 'catVic.jpg')
image_write(img, img_path)
plot(as.raster(img))
Prepare the image and look at the default superpixels.
image_prep <- function(x) {
  arrays <- lapply(x, function(path) {
    img <- image_load(path, target_size = c(224, 224))
    x <- image_to_array(img)
    x <- array_reshape(x, c(1, dim(x)))
    x <- imagenet_preprocess_input(x)
  })
  do.call(abind::abind, c(arrays, list(along = 1)))
}
explainer <- lime(img_path, model, image_prep)
res <- predict(model, image_prep(img_path))
imagenet_decode_predictions(res)
## [[1]]
## class_name class_description score
## 1 n02123394 Persian_cat 0.44745401
## 2 n02328150 Angora 0.05786140
## 3 n02127052 lynx 0.04759105
## 4 n02112018 Pomeranian 0.04175223
## 5 n03014705 chest 0.04150921
model_labels <- readRDS(system.file('extdata', 'imagenet_labels.rds', package = 'lime'))
explainer <- lime(img_path, as_classifier(model, model_labels), image_prep)
plot_superpixels(img_path)
Increase the number of superpixels to 300 instead of 200.
plot_superpixels(img_path, n_superpixels = 300, weight = 40)
Plot the explanation for up to 5 labels instead of 2.
explanation <- explain(img_path, explainer, n_labels = 5, n_features = 20)
plot_image_explanation(explanation)
The following will block out everything that the model did not consider part of its classification.
plot_image_explanation(explanation, display = 'block', threshold = 0.01)
The following will look at the areas that agree with and contradict the classification.
plot_image_explanation(explanation, threshold = 0, show_negative = TRUE, fill_alpha = 0.6)
In conclusion, the model classified the cat first, but looking at the fifth label shows that it was also able to identify the chest/dresser in the picture.